/**
 * Copyright (c) 2016-2020 https://github.com/zhaohuatai
 *
 * contact z_huatai@qq.com
 *  
 */
package org.zfes.snowy.auth.shiro.realm;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authc.AuthenticationInfo;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.authc.SimpleAuthenticationInfo;
import org.apache.shiro.authc.UnknownAccountException;
import org.apache.shiro.authc.UsernamePasswordToken;
import org.apache.shiro.authz.AuthorizationException;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.cache.Cache;
import org.apache.shiro.ldap.UnsupportedAuthenticationMechanismException;
import org.apache.shiro.session.Session;
import org.apache.shiro.subject.PrincipalCollection;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.util.SerialSimpleByteSource;
import org.jasig.cas.client.authentication.AttributePrincipal;
import org.jasig.cas.client.validation.Assertion;
import org.jasig.cas.client.validation.Cas20ServiceTicketValidator;
import org.jasig.cas.client.validation.Saml11TicketValidator;
import org.jasig.cas.client.validation.TicketValidationException;
import org.jasig.cas.client.validation.TicketValidator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.event.EventListener;
import org.springframework.scheduling.annotation.Async;
import org.zfes.snowy.auth.AuthConsts;
import org.zfes.snowy.auth.biz.model.AuthUser;
import org.zfes.snowy.auth.manager.IAuthManager;
import org.zfes.snowy.auth.shiro.cas.SnowyCasAuthenticationException;
import org.zfes.snowy.auth.shiro.cas.SnowyCasToken;
import org.zfes.snowy.auth.shiro.event.AuthCacheClearType;
import org.zfes.snowy.auth.shiro.event.AuthChangeEvent;
import org.zfes.snowy.auth.shiro.jwt.token.JWTTokenParser;
import org.zfes.snowy.auth.shiro.jwt.token.JwtAuthToken;
import org.zfes.snowy.auth.shiro.util.JWTTokenVerifyPublicKeyUtil;
import org.zfes.snowy.auth.shiro.weichat.token.WeiChatAuthToken;
import org.zfes.snowy.core.util.ZStrUtil;

import com.google.common.collect.Lists;


/**
 * Shiro realm that authenticates several kinds of subject tokens (username/password,
 * JWT, WeChat and CAS) and builds authorization info from role and permission codes
 * loaded through {@link IAuthManager}, scoped by {@code appKey}.
 *
 * <p>It also exposes "remote" entry points for authentication/authorization and evicts
 * cached authentication/authorization data when an {@link AuthChangeEvent} is published.
 */
public class MultiSubjectRealm extends SnowyRemoteRealm {
    protected final Logger logger = LoggerFactory.getLogger(this.getClass());

    /** Source of users, role codes and permission codes. */
    private IAuthManager authManager;

    /** Application key scoping role/permission lookups and the session cache-version attribute. */
    private String appKey;

    /** CAS server URL prefix (example: http://host:port/cas). */
    private String casServerUrlPrefix;

    /** This application's CAS service URL (example: http://host:port/mycontextpath/shiro-cas). */
    private String casService;

    /** Ticket validation protocol: "CAS" (default, CAS 2.0) or "SAML" (SAML 1.1), case-insensitive. */
    private String validationProtocol = "CAS";

    /** CAS ticket validator; created lazily by {@link #ensureTicketValidator()} unless injected. */
    private TicketValidator ticketValidator;

    /**
     * Returns whether this realm can authenticate the given token type.
     *
     * @param authcToken the token submitted for authentication
     * @return {@code true} for username/password, JWT, WeChat and CAS tokens
     */
    @Override
    public boolean supports(AuthenticationToken authcToken) {
        return authcToken instanceof UsernamePasswordToken
                || authcToken instanceof JwtAuthToken
                || authcToken instanceof WeiChatAuthToken
                || authcToken instanceof SnowyCasToken;
    }

    /**
     * Authentication: dispatches to the handler for the concrete token type.
     *
     * @param authcToken the submitted token
     * @return authentication info for the resolved principal
     * @throws AuthenticationException if the user is unknown, disabled, the CAS ticket is
     *         invalid, or the token is {@code null}
     * @throws UnsupportedAuthenticationMechanismException for unsupported token types
     */
    @Override
    protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken authcToken)
            throws AuthenticationException {
        if (authcToken == null) {
            throw new AuthenticationException("AuthenticationToken cannot be null.");
        }
        if (authcToken instanceof UsernamePasswordToken) {
            return authenticateByPassword((UsernamePasswordToken) authcToken);
        }
        if (authcToken instanceof JwtAuthToken) {
            return authenticateByJwt((JwtAuthToken) authcToken);
        }
        if (authcToken instanceof WeiChatAuthToken) {
            return authenticateByWeiChat((WeiChatAuthToken) authcToken);
        }
        if (authcToken instanceof SnowyCasToken) {
            return authenticateByCas((SnowyCasToken) authcToken);
        }
        throw new UnsupportedAuthenticationMechanismException(" the target  AuthenticationToken : "
                + authcToken.getClass().getName() + " is not support ");
    }

    /** Username/password: loads the local user and returns salted stored credentials. */
    private AuthenticationInfo authenticateByPassword(UsernamePasswordToken token) {
        AuthUser user = authManager.loadUserByName(token.getUsername())
                .orElseThrow(() -> new UnknownAccountException("用户名或密码错误"));
        assertEnabled(user);
        return buildSaltedInfo(user);
    }

    /** JWT: the principal is the username parsed from the token; the stored credential is "". */
    private AuthenticationInfo authenticateByJwt(JwtAuthToken jwtToken) {
        String principal = JWTTokenParser.parseUsername(jwtToken.getToken(),
                JWTTokenVerifyPublicKeyUtil.readJwtVerifyPublicKey());
        return new SimpleAuthenticationInfo(principal, "", getName());
    }

    /**
     * WeChat: provisional pass-through. The token's own principal and credentials are
     * returned as the stored values; no local lookup or verification happens here.
     */
    private AuthenticationInfo authenticateByWeiChat(WeiChatAuthToken weiChatAuthToken) {
        String principal = weiChatAuthToken.getPrincipal().toString();
        String credentials = weiChatAuthToken.getCredentials().toString();
        return new SimpleAuthenticationInfo(principal, credentials, getName());
    }

    /**
     * CAS: validates the service ticket with the CAS server, then maps the CAS principal
     * name to a local, enabled user.
     */
    private AuthenticationInfo authenticateByCas(SnowyCasToken casToken) {
        String ticket = (String) casToken.getTicket();
        if (ZStrUtil.hasNoText(ticket)) {
            throw new SnowyCasAuthenticationException("Unable to validate ticket [" + ticket + "]");
        }
        Assertion casAssertion;
        try {
            // Ticket validation against the CAS server.
            casAssertion = ensureTicketValidator().validate(ticket, getCasService());
        } catch (TicketValidationException e) {
            throw new SnowyCasAuthenticationException("Unable to validate ticket [" + ticket + "]", e);
        }
        AttributePrincipal casPrincipal = casAssertion.getPrincipal();
        String userId = casPrincipal.getName();
        AuthUser user = authManager.loadUserByName(userId)
                .orElseThrow(() -> new UnknownAccountException("CAS认证错误，未查询到本地用户"));
        // Same status check as password login, so disabled users cannot log in through CAS.
        assertEnabled(user);
        return buildSaltedInfo(user);
    }

    /** Rejects users whose enabled flag is not {@code true}; a {@code null} flag counts as disabled. */
    private static void assertEnabled(AuthUser user) {
        if (!Boolean.TRUE.equals(user.getEnabled())) {
            throw new AuthenticationException("用户状态错误");
        }
    }

    /** Builds authentication info (account, stored password, salt) for a local user. */
    private SimpleAuthenticationInfo buildSaltedInfo(AuthUser user) {
        SimpleAuthenticationInfo info =
                new SimpleAuthenticationInfo(user.getAccount(), user.getPassword(), getName());
        info.setCredentialsSalt(new SerialSimpleByteSource(user.getSalt()));
        return info;
    }

    /**
     * Authorization: loads roles and permissions for this realm's {@code appKey}.
     */
    @Override
    protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principals) {
        return doLoadAuthorizationInfo(appKey, principals);
    }

    //------------------------remote------------------------------------------

    /**
     * Remote authentication entry point.
     *
     * @param appKey currently unused; authentication does not depend on the application
     * @param authcToken the submitted token
     * @return the same result as {@link #doGetAuthenticationInfo(AuthenticationToken)}
     */
    public AuthenticationInfo remoteDoGetAuthenticationInfo(String appKey, AuthenticationToken authcToken) {
        return doGetAuthenticationInfo(authcToken);
    }

    /**
     * Remote authorization entry point.
     *
     * @param appKey application whose roles and permissions are loaded
     * @param principals the subject's principals
     * @return authorization info scoped to {@code appKey}
     */
    public AuthorizationInfo remoteDoGetAuthorizationInfo(String appKey, PrincipalCollection principals) {
        return doLoadAuthorizationInfo(appKey, principals);
    }

    //-------------------------remote-----------------------------------------

    /**
     * Loads role codes and default-role permission codes for the primary principal within
     * {@code appKey}. An unknown user yields empty roles and permissions.
     *
     * @throws AuthorizationException if {@code principals} is {@code null}
     */
    private AuthorizationInfo doLoadAuthorizationInfo(String appKey, PrincipalCollection principals) {
        if (principals == null) {
            throw new AuthorizationException("PrincipalCollection method argument cannot be null.");
        }
        List<String> roleCodes = Lists.newArrayList();
        List<String> permCodes = Lists.newArrayList();
        String userName = (String) principals.getPrimaryPrincipal();
        Optional<AuthUser> authUserOp = authManager.loadUserByName(userName);
        if (authUserOp.isPresent()) {
            AuthUser user = authUserOp.get();
            roleCodes = nullToEmpty(authManager.loadRoleCodesUserHave(user.getId(), appKey, true));
            permCodes = nullToEmpty(authManager.loadPermCodesInUserDefaultRole(user.getId(), appKey, true));
        }
        SimpleAuthorizationInfo authorizationInfo = new SimpleAuthorizationInfo();
        authorizationInfo.addRoles(roleCodes);
        authorizationInfo.addStringPermissions(permCodes);
        return authorizationInfo;
    }

    /** Guards SimpleAuthorizationInfo.addRoles/addStringPermissions against a null list. */
    private static List<String> nullToEmpty(List<String> codes) {
        return codes == null ? Lists.<String>newArrayList() : codes;
    }

    //----------------------------------------------------------

    /**
     * Returns the CAS ticket validator, creating it on first use. Access is not
     * synchronized, so concurrent first calls may each build a validator; the last
     * assignment wins.
     */
    protected TicketValidator ensureTicketValidator() {
        if (this.ticketValidator == null) {
            this.ticketValidator = createTicketValidator();
        }
        return this.ticketValidator;
    }

    /**
     * Creates a SAML 1.1 validator when {@code validationProtocol} is "saml"
     * (case-insensitive); otherwise creates a CAS 2.0 service-ticket validator.
     */
    protected TicketValidator createTicketValidator() {
        String urlPrefix = getCasServerUrlPrefix();
        if ("saml".equalsIgnoreCase(getValidationProtocol())) {
            return new Saml11TicketValidator(urlPrefix);
        }
        return new Cas20ServiceTicketValidator(urlPrefix);
    }

    /**
     * Evicts cached authentication/authorization data according to the event's clear type.
     * For authorization changes, it also bumps the per-application cache version stored in
     * the event subject's session.
     */
    @Async
    @EventListener
    public void handleAuthChangeEvent(AuthChangeEvent authorEvent) {
        AuthCacheClearType cacheClearType = authorEvent.getCacheType();
        if (cacheClearType == null) {
            logger.warn("Ignoring AuthChangeEvent without a cache clear type");
            return;
        }
        Subject subject = authorEvent.getSubject();
        switch (cacheClearType) {
            case allCachedAuthenticationInfo:
                clearAllCachedAuthenticationInfo();
                break;
            case cachedAuthorizationInfoBySubject:
                clearCachedAuthorizationInfo(subject);
                authorizationVersionPlus(subject);
                break;
            case cachedAuthorizationInfoByRole:
                // NOTE(review): clearing everything first makes the by-role pass a no-op;
                // kept to preserve existing behaviour — confirm whether the full clear is intended.
                clearAllCachedAuthorizationInfo();
                clearCachedAuthorizationInfoByRole(authorEvent.getRoleCode());
                authorizationVersionPlus(subject);
                break;
            case allCachedAuthorizationInfo:
                clearAllCachedAuthorizationInfo();
                authorizationVersionPlus(subject);
                break;
            case allCache:
                clearAllCachedAuthenticationInfo();
                clearAllCachedAuthorizationInfo();
                authorizationVersionPlus(subject);
                break;
            default:
                break;
        }
    }

    /**
     * Permissions changed: increments (or initialises to 1) the session attribute
     * {@code appKey + AuthConsts.authorCacheVersion}. The session is obtained (and
     * created if absent) only here; a {@code null} subject is skipped.
     */
    private void authorizationVersionPlus(Subject subject) {
        if (subject == null) {
            return;
        }
        Session session = subject.getSession();
        String versionKey = appKey + AuthConsts.authorCacheVersion;
        Long current = (Long) session.getAttribute(versionKey);
        session.setAttribute(versionKey, current == null ? 1L : current + 1);
    }

    /**
     * Clears the cached authentication info of the given subject.
     *
     * @param subject the subject whose entry is evicted; ignored if it or its principals are null
     */
    public void clearCachedAuthenticationInfo(Subject subject) {
        if (subject != null && subject.getPrincipals() != null) {
            super.clearCachedAuthenticationInfo(subject.getPrincipals());
        }
    }

    /**
     * Clears the authentication info cache of every user.
     */
    public void clearAllCachedAuthenticationInfo() {
        removeAllEntries(getAuthenticationCache());
    }

    /**
     * Clears the cached authorization info of the given subject.
     *
     * @param subject the subject whose entry is evicted; ignored if it or its principals are null
     */
    public void clearCachedAuthorizationInfo(Subject subject) {
        if (subject != null && subject.getPrincipals() != null) {
            super.clearCachedAuthorizationInfo(subject.getPrincipals());
        }
    }

    /**
     * Evicts every cached authorization entry whose roles contain {@code roleCode}.
     *
     * @param roleCode role code to match; blank values are ignored
     */
    public void clearCachedAuthorizationInfoByRole(String roleCode) {
        if (!ZStrUtil.hasText(roleCode)) {
            return;
        }
        Cache<Object, AuthorizationInfo> cache = getAuthorizationCache();
        if (cache == null) {
            return;
        }
        // Iterate a snapshot: some Cache.keys() implementations return a live view.
        for (Object key : Lists.newArrayList(cache.keys())) {
            AuthorizationInfo info = cache.get(key);
            // The entry may have been evicted since keys() was taken.
            if (info != null && info.getRoles() != null && info.getRoles().contains(roleCode)) {
                cache.remove(key);
            }
        }
    }

    /**
     * Clears the authorization info cache of every user.
     */
    public void clearAllCachedAuthorizationInfo() {
        removeAllEntries(getAuthorizationCache());
    }

    /**
     * Removes every key from {@code cache} (no-op when null). Keys are copied first because
     * some implementations (e.g. Shiro's MapCache) return a live view from keys(), and
     * removing while iterating that view can throw ConcurrentModificationException.
     */
    private static <K, V> void removeAllEntries(Cache<K, V> cache) {
        if (cache == null) {
            return;
        }
        for (K key : Lists.newArrayList(cache.keys())) {
            cache.remove(key);
        }
    }

    public String getAppKey() {
        return appKey;
    }

    public void setAppKey(String appKey) {
        this.appKey = appKey;
    }

    public IAuthManager getAuthManager() {
        return authManager;
    }

    public void setAuthManager(IAuthManager authManager) {
        this.authManager = authManager;
    }

    public TicketValidator getTicketValidator() {
        return ticketValidator;
    }

    public void setTicketValidator(TicketValidator ticketValidator) {
        this.ticketValidator = ticketValidator;
    }

    public String getCasServerUrlPrefix() {
        return casServerUrlPrefix;
    }

    public void setCasServerUrlPrefix(String casServerUrlPrefix) {
        this.casServerUrlPrefix = casServerUrlPrefix;
    }

    public String getCasService() {
        return casService;
    }

    public void setCasService(String casService) {
        this.casService = casService;
    }

    public String getValidationProtocol() {
        return validationProtocol;
    }

    public void setValidationProtocol(String validationProtocol) {
        this.validationProtocol = validationProtocol;
    }
}
